"""Demo: a LangChain chat chain whose message history is persisted in Redis."""

import os

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain_core.runnables import RunnableWithMessageHistory
from langchain_ollama import ChatOllama
from langchain_redis import RedisChatMessageHistory

# Redis connection string; override via the REDIS_URL environment variable.
# NOTE(review): 56379 is not the standard Redis port (6379) — presumably a
# deliberate local port mapping; confirm before deploying.
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:56379/0")
print(f"正在连接到Redis: {REDIS_URL}")


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the Redis-backed chat message history for *session_id*.

    A new ``RedisChatMessageHistory`` handle is created on every call; all
    handles for the same session id share the same underlying Redis entries.
    """
    history = RedisChatMessageHistory(session_id, redis_url=REDIS_URL)
    return history


# Prompt layout: system instructions, then the accumulated conversation
# history, then the latest user input.
_prompt_messages = [
    SystemMessagePromptTemplate.from_template(
        "你是一个擅长{ability}的助手。回答不超过20个字。尽量用中文回复。"
    ),
    MessagesPlaceholder(variable_name="history"),
    HumanMessagePromptTemplate.from_template("{input}"),
]
prompt = ChatPromptTemplate.from_messages(_prompt_messages)

# Local llama3 model with moderate sampling temperature.
model = ChatOllama(model="llama3", temperature=0.5)

runnable = prompt | model

# Wrap the chain so every invocation automatically loads and stores the
# chat messages for the session_id supplied via `config`.
with_message_history = RunnableWithMessageHistory(
    runnable,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history",
)

# Three turns in the same session ("aaa"): the chain is expected to answer
# the follow-up questions using the history stored in Redis after turn 1.
_SESSION_CONFIG = {"configurable": {"session_id": "aaa"}}

for _question in (
    "余弦函数是什么意思？",
    "什么,你刚刚说什么？",
    "麻烦用最简洁的语言重新介绍？",
):
    # Each invoke appends the question/answer pair to the session history.
    print(with_message_history.invoke(
        {"ability": "数学", "input": _question},
        config=_SESSION_CONFIG,
    ))
